{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f93b7d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:33.250.525 [mindspore/run_check/_check_version.py:357] MindSpore version 2.3.1 and Ascend AI software package (Ascend Data Center Solution)version 7.5 does not match, the version of software package expect one of ['7.2', '7.3']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.646.327 [mindspore/run_check/_check_version.py:375] MindSpore version 2.3.1 and \"te\" wheel package version 7.5 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.649.853 [mindspore/run_check/_check_version.py:382] MindSpore version 2.3.1 and \"hccl\" wheel package version 7.5 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:35.652.151 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:36.654.522 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:48:37.657.277 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 1\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.910 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import mindspore\n",
    "from mindnlp.transformers import AutoModelForSeq2SeqLM\n",
    "from mindnlp.peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType\n",
    "from mindnlp.dataset import load_dataset\n",
    "from mindnlp.core import ops\n",
    "\n",
    "from mindnlp.transformers import AutoTokenizer\n",
    "from mindnlp.common.optimization import get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "\n",
    "model_name_or_path = \"bigscience/mt0-large\"\n",
    "tokenizer_name_or_path = \"bigscience/mt0-large\"\n",
    "\n",
    "checkpoint_name = \"financial_sentiment_analysis_lora_v1.ckpt\"\n",
    "max_length = 128\n",
    "lr = 1e-3\n",
    "num_epochs = 3\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0850ac",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n",
      "trainable params: 2,359,296 || all params: 1,231,940,608 || trainable%: 0.19151053100118282\n"
     ]
    }
   ],
   "source": [
    "# creating model\n",
    "peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
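  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f2c3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (a sketch, not part of the original recipe): with\n",
    "# LoRA, only the injected low-rank A/B matrices should remain trainable.\n",
    "# `trainable_params()` is the same API the optimizer consumes further down.\n",
    "for param in model.trainable_params()[:4]:\n",
    "    print(param.name, tuple(param.shape))"
   ]
  },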
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ee2babf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ccab873e2c5a4eb884364c23445f920f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/6.04k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from /home/lvyufeng/.cache/huggingface/modules/datasets_modules/datasets/financial_phrasebank/dcba1275938e14aa0a7f6ba18fef92e9e957c0e5436dc58ad71c41c2906eb4f2 (last modified on Tue Aug 13 21:20:09 2024) since it couldn't be found locally at financial_phrasebank, or remotely on the Hugging Face Hub.\n"
     ]
    }
   ],
   "source": [
    "mindspore.dataset.config.set_seed(123)\n",
    "# loading dataset\n",
    "dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d1ecd7a0-2e48-405e-940b-001a96ccc18f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['negative', 'neutral', 'positive']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classes = dataset.source.ds.features[\"label\"].names\n",
    "classes"
   ]
  },
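  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e3d4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ids follow the dataset's ClassLabel feature, so `classes[label]`\n",
    "# recovers the text form that serves as the generation target below.\n",
    "print({i: name for i, name in enumerate(classes)})"
   ]
  },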
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8cbbf475-89b6-4368-a7eb-730dc532558e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(4027200:281473120071712,MainProcess):2024-10-21-16:49:36.547.450 [mindspore/dataset/engine/datasets.py:1217] Dataset is shuffled before split.\n"
     ]
    }
   ],
   "source": [
    "train_dataset, validation_dataset = dataset.shuffle(64).split([0.9, 0.1])"
   ]
  },
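  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f4e5a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: confirm the 90/10 split. `get_dataset_size()` is the same call\n",
    "# the training loop uses below for the tqdm totals.\n",
    "print(train_dataset.get_dataset_size(), validation_dataset.get_dataset_size())"
   ]
  },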
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a89748b4-3193-4419-be39-2dd7907e349e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_text_label(sentence, label):\n",
    "    return sentence, label, classes[label.item()]\n",
    "\n",
    "train_dataset = train_dataset.map(add_text_label, ['sentence', 'label'], ['sentence', 'label', 'text_label'])\n",
    "validation_dataset = validation_dataset.map(add_text_label, ['sentence', 'label'], ['sentence', 'label', 'text_label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "396f46c4-45ba-4d45-a98f-fff173038b5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sentence': Tensor(shape=[], dtype=String, value= 'The gross area of the Innova 2 project will be about 10,000 sq m ( 107,600 sq ft ) .'),\n",
       " 'label': Tensor(shape=[], dtype=Int64, value= 1),\n",
       " 'text_label': Tensor(shape=[], dtype=String, value= 'neutral')}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(train_dataset.create_dict_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "603ea9ef-88dd-4c8e-bd10-df7a64940b01",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "86d29e1a-eb31-48bf-a6e3-bc3cfa26cde5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mindnlp.dataset import BaseMapFunction\n",
    "from threading import Lock\n",
    "lock = Lock()\n",
    "\n",
    "class MapFunc(BaseMapFunction):\n",
    "    def __call__(self, sentence, label, text_label):\n",
    "        lock.acquire()\n",
    "        model_inputs = tokenizer(sentence, max_length=max_length, padding=\"max_length\", truncation=True)\n",
    "        labels = tokenizer(text_label, max_length=3, padding=\"max_length\", truncation=True)\n",
    "        lock.release()\n",
    "        labels = labels['input_ids']\n",
    "        labels = np.where(np.equal(labels, tokenizer.pad_token_id), -100, labels)\n",
    "        return model_inputs['input_ids'], model_inputs['attention_mask'], labels\n",
    "\n",
    "\n",
    "def get_dataset(dataset, tokenizer, shuffle=True):\n",
    "    input_colums=['sentence', 'label', 'text_label']\n",
    "    output_columns=['input_ids', 'attention_mask', 'labels']\n",
    "    dataset = dataset.map(MapFunc(input_colums, output_columns),\n",
    "                          input_colums, output_columns)\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(64)\n",
    "    dataset = dataset.batch(batch_size)\n",
    "    return dataset\n",
    "\n",
    "train_dataset = get_dataset(train_dataset, tokenizer)\n",
    "eval_dataset = get_dataset(validation_dataset, tokenizer, shuffle=False)"
   ]
  },
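  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a5f6b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional round-trip check (a sketch): the -100 entries only mask the loss,\n",
    "# so restoring them to the pad id should decode back to the plain text\n",
    "# labels ('negative'/'neutral'/'positive').\n",
    "check_batch = next(eval_dataset.create_dict_iterator(output_numpy=True))\n",
    "restored = np.where(check_batch['labels'] == -100, tokenizer.pad_token_id, check_batch['labels'])\n",
    "print(tokenizer.batch_decode(restored, skip_special_tokens=True))"
   ]
  },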
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3b955679-aab8-4f82-91cf-06cf7e3db1c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': Tensor(shape=[8, 128], dtype=Int64, value=\n",
       " [[   486,  33720,   9340 ...      0,      0,      0],\n",
       "  [   259,  85336,    714 ...      0,      0,      0],\n",
       "  [   486,  21511,    345 ...      0,      0,      0],\n",
       "  ...\n",
       "  [169774,    262,    639 ...      0,      0,      0],\n",
       "  [  9981,    259,  45814 ...      0,      0,      0],\n",
       "  [   486,  21982,    345 ...      0,      0,      0]]),\n",
       " 'attention_mask': Tensor(shape=[8, 128], dtype=Int64, value=\n",
       " [[1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  ...\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0]]),\n",
       " 'labels': Tensor(shape=[8, 3], dtype=Int64, value=\n",
       " [[59006,     1,  -100],\n",
       "  [59006,     1,  -100],\n",
       "  [59006,     1,  -100],\n",
       "  ...\n",
       "  [59006,     1,  -100],\n",
       "  [18205,     1,  -100],\n",
       "  [59006,     1,  -100]])}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(train_dataset.create_dict_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f733a3c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.core import optim\n",
    "# optimizer and lr scheduler\n",
    "optimizer = optim.AdamW(model.trainable_params(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataset) * num_epochs),\n",
    ")"
   ]
  },
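  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b6a7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With num_warmup_steps=0 the schedule above is a pure linear decay from `lr`\n",
    "# to 0 over the whole run. A standalone sketch of that rule (hypothetical\n",
    "# helper for illustration; the real schedule lives in `lr_scheduler`):\n",
    "total_steps = len(train_dataset) * num_epochs\n",
    "def linear_lr(step, base_lr=lr, total=total_steps):\n",
    "    return base_lr * max(0.0, (total - step) / total)\n",
    "print(linear_lr(0), linear_lr(total_steps // 2), linear_lr(total_steps))"
   ]
  },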
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b3a4090",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [05:10<00:00,  1.22s/it]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:11<00:00,  2.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=Tensor(shape=[], dtype=Float32, value= 1.35495) train_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.303762) eval_ppl=Tensor(shape=[], dtype=Float32, value= 1.09358) eval_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.0894594)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [04:42<00:00,  1.11s/it]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:10<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=Tensor(shape=[], dtype=Float32, value= 1.06017) train_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.0584264) eval_ppl=Tensor(shape=[], dtype=Float32, value= 1.03234) eval_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.0318265)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [04:41<00:00,  1.10s/it]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:10<00:00,  2.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=Tensor(shape=[], dtype=Float32, value= 1.03601) train_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.0353789) eval_ppl=Tensor(shape=[], dtype=Float32, value= 1.0308) eval_epoch_loss=Tensor(shape=[], dtype=Float32, value= 0.0303333)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.core import value_and_grad\n",
    "# training and evaluation\n",
    "def forward_fn(**batch):\n",
    "    outputs = model(**batch)\n",
    "    loss = outputs.loss\n",
    "    return loss\n",
    "\n",
    "grad_fn = value_and_grad(forward_fn, model.trainable_params())\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.set_train()\n",
    "    total_loss = 0\n",
    "    train_total_size = train_dataset.get_dataset_size()\n",
    "    for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):\n",
    "        optimizer.zero_grad()\n",
    "        loss = grad_fn(**batch)\n",
    "        optimizer.step()\n",
    "        total_loss += loss.float()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "    model.set_train(False)\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    eval_total_size = eval_dataset.get_dataset_size()\n",
    "    for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):\n",
    "        with mindspore._no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(ops.argmax(outputs.logits, -1).asnumpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataset)\n",
    "    eval_ppl = ops.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataset)\n",
    "    train_ppl = ops.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6cafa67b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy=97.34513274336283 % on the evaluation dataset\n",
      "eval_preds[:10]=['neutral', 'neutral', 'neutral', 'neutral', 'positive', 'positive', 'neutral', 'positive', 'positive', 'positive']\n",
      "ground_truth[:10]=['neutral', 'neutral', 'neutral', 'neutral', 'positive', 'neutral', 'neutral', 'positive', 'positive', 'positive']\n"
     ]
    }
   ],
   "source": [
    "# print accuracy\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "ground_truth = []\n",
    "\n",
    "for pred, data in zip(eval_preds, validation_dataset.create_dict_iterator(output_numpy=True)):\n",
    "    true = str(data['text_label'])\n",
    "    ground_truth.append(true)\n",
    "    if pred.strip() == true.strip():\n",
    "        correct += 1\n",
    "    total += 1\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=} % on the evaluation dataset\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{ground_truth[:10]=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "26e88eb1-3d24-488b-ad62-0c18b81b861e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Tensor(shape=[8, 128], dtype=Int64, value=\n",
       " [[   486,   2733,   6396 ...      0,      0,      0],\n",
       "  [   486,    259, 150106 ...      0,      0,      0],\n",
       "  [   486,   5835,    259 ...      0,      0,      0],\n",
       "  ...\n",
       "  [  1385,    339,    259 ...      0,      0,      0],\n",
       "  [   259,  18948,    776 ...      0,      0,      0],\n",
       "  [   486,    259,  20147 ...      0,      0,      0]]),\n",
       " Tensor(shape=[8, 128], dtype=Int64, value=\n",
       " [[1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  ...\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0],\n",
       "  [1, 1, 1 ... 0, 0, 0]]),\n",
       " Tensor(shape=[8, 3], dtype=Int64, value=\n",
       " [[59006,     1,  -100],\n",
       "  [59006,     1,  -100],\n",
       "  [59006,     1,  -100],\n",
       "  ...\n",
       "  [59006,     1,  -100],\n",
       "  [59006,     1,  -100],\n",
       "  [18205,     1,  -100]])]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(eval_dataset.create_tuple_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a8de6005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bd20cd4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.1M\tbigscience/mt0-large_LORA_SEQ_2_SEQ_LM/adapter_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.ckpt\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "76c2fc29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "37d712ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "neutral\n",
      "{'input_ids': Tensor(shape=[1, 2], dtype=Int64, value=\n",
      "[[59006,     1]]), 'attention_mask': Tensor(shape=[1, 2], dtype=Int64, value=\n",
      "[[1, 1]])}\n",
      "[[    0 59006     1]]\n",
      "['neutral']\n"
     ]
    }
   ],
   "source": [
    "model.set_train(False)\n",
    "example = next(validation_dataset.create_dict_iterator(output_numpy=True))\n",
    "\n",
    "print(example['text_label'])\n",
    "inputs = tokenizer(example['text_label'], return_tensors=\"ms\")\n",
    "print(inputs)\n",
    "\n",
    "with mindspore._no_grad():\n",
    "    outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.asnumpy(), skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
